# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Preamble written verbatim at the top of every generated .cu file: the
# Apache-2.0 license banner, a note that the file is generated (edit this
# script instead), and the include of raft/neighbors/ivf_flat-inl.cuh, which
# presumably provides the ivf_flat template definitions instantiated by the
# macros below -- confirm against that header.
header = """/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * NOTE: this file is generated by ivf_flat_00_generate.py
 *
 * Make changes there and run in this directory:
 *
 * > python ivf_flat_00_generate.py
 *
 */

#include <raft/neighbors/ivf_flat-inl.cuh>
"""

# Concrete template arguments to instantiate. Each value is the
# (element type T, index type IdxT) pair substituted into the macros; the key
# is the "<T>_<IdxT>" suffix used in the generated file names
# (ivf_flat_<entry point>_<suffix>.cu).
types = {
    f"{elem_t}_{index_t}": (elem_t, index_t)
    for elem_t, index_t in (
        ("float", "int64_t"),
        ("int8_t", "int64_t"),
        ("uint8_t", "int64_t"),
    )
}

# C++ source for a function-like macro that explicitly instantiates the five
# raft::neighbors::ivf_flat::build overloads for a given (T, IdxT):
#   - raw device pointer + n_rows/dim, returning an index
#   - device_matrix_view, returning an index
#   - device_matrix_view, filling an existing index
#   - host_matrix_view, returning an index
#   - host_matrix_view, filling an existing index
# Each "\\" becomes a single backslash (macro line continuation) in the
# generated file; the last declaration has none, ending the macro.
build_macro = """
#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT)        \\
  template auto raft::neighbors::ivf_flat::build<T, IdxT>(        \\
    raft::resources const& handle,                                \\
    const raft::neighbors::ivf_flat::index_params& params,        \\
    const T* dataset,                                             \\
    IdxT n_rows,                                                  \\
    uint32_t dim)                                                 \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                  \\
                                                                  \\
  template auto raft::neighbors::ivf_flat::build<T, IdxT>(        \\
    raft::resources const& handle,                                \\
    const raft::neighbors::ivf_flat::index_params& params,        \\
    raft::device_matrix_view<const T, IdxT, row_major> dataset)   \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                  \\
                                                                  \\
  template void raft::neighbors::ivf_flat::build<T, IdxT>(        \\
    raft::resources const& handle,                                \\
    const raft::neighbors::ivf_flat::index_params& params,        \\
    raft::device_matrix_view<const T, IdxT, row_major> dataset,   \\
    raft::neighbors::ivf_flat::index<T, IdxT>& idx);              \\
                                                                  \\
  template auto raft::neighbors::ivf_flat::build<T, IdxT>(        \\
    raft::resources const& handle,                                \\
    const raft::neighbors::ivf_flat::index_params& params,        \\
    raft::host_matrix_view<const T, IdxT, row_major> dataset)     \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                  \\
                                                                  \\
  template void raft::neighbors::ivf_flat::build<T, IdxT>(        \\
    raft::resources const& handle,                                \\
    const raft::neighbors::ivf_flat::index_params& params,        \\
    raft::host_matrix_view<const T, IdxT, row_major> dataset,     \\
    raft::neighbors::ivf_flat::index<T, IdxT>& idx);
"""

# C++ source for a function-like macro that explicitly instantiates the six
# raft::neighbors::ivf_flat::extend overloads for a given (T, IdxT):
#   - raw pointers + n_rows, returning a new index built from orig_index
#   - device views, returning a new index built from orig_index
#   - raw pointers + n_rows, extending *index in place
#   - device views, extending *index in place
#   - host views, returning a new index built from orig_index
#   - host views, extending *index in place
# Every continuation backslash sits in the same column and the final line has
# no trailing whitespace, so the generated .cu files are consistently
# formatted. (The host-view overloads previously used a 3-space indent, a
# misaligned backslash, a different spelling of the handle/return types, and
# trailing whitespace.) Each "\\" becomes a single backslash in the output.
extend_macro = """
#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT)                \\
  template auto raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index,           \\
    const T* new_vectors,                                                  \\
    const IdxT* new_indices,                                               \\
    IdxT n_rows)                                                           \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                           \\
                                                                           \\
  template auto raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    raft::device_matrix_view<const T, IdxT, row_major> new_vectors,        \\
    std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \\
    const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index)           \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                           \\
                                                                           \\
  template void raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    raft::neighbors::ivf_flat::index<T, IdxT>* index,                      \\
    const T* new_vectors,                                                  \\
    const IdxT* new_indices,                                               \\
    IdxT n_rows);                                                          \\
                                                                           \\
  template void raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    raft::device_matrix_view<const T, IdxT, row_major> new_vectors,        \\
    std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \\
    raft::neighbors::ivf_flat::index<T, IdxT>* index);                     \\
                                                                           \\
  template auto raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    raft::host_matrix_view<const T, IdxT, row_major> new_vectors,          \\
    std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,   \\
    const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index)           \\
    ->raft::neighbors::ivf_flat::index<T, IdxT>;                           \\
                                                                           \\
  template void raft::neighbors::ivf_flat::extend<T, IdxT>(                \\
    raft::resources const& handle,                                         \\
    raft::host_matrix_view<const T, IdxT, row_major> new_vectors,          \\
    std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,   \\
    raft::neighbors::ivf_flat::index<T, IdxT>* index);
"""

# C++ source for a function-like macro that explicitly instantiates the two
# raft::neighbors::ivf_flat::search overloads for a given (T, IdxT):
#   - raw pointers with n_queries/k and an rmm::device_async_resource_ref mr
#   - device_matrix_view queries/neighbors/distances
# Distances are always float in both overloads. Each "\\" becomes a single
# backslash (macro line continuation) in the generated file.
search_macro = """
#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT)        \\
  template void raft::neighbors::ivf_flat::search<T, IdxT>(        \\
    raft::resources const& handle,                                 \\
    const raft::neighbors::ivf_flat::search_params& params,        \\
    const raft::neighbors::ivf_flat::index<T, IdxT>& index,        \\
    const T* queries,                                              \\
    uint32_t n_queries,                                            \\
    uint32_t k,                                                    \\
    IdxT* neighbors,                                               \\
    float* distances,                                              \\
    rmm::device_async_resource_ref mr);                            \\
                                                                   \\
  template void raft::neighbors::ivf_flat::search<T, IdxT>(        \\
    raft::resources const& handle,                                 \\
    const raft::neighbors::ivf_flat::search_params& params,        \\
    const raft::neighbors::ivf_flat::index<T, IdxT>& index,        \\
    raft::device_matrix_view<const T, IdxT, row_major> queries,    \\
    raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,     \\
    raft::device_matrix_view<float, IdxT, row_major> distances);
"""

# Entry points to instantiate, keyed by the fragment used in generated file
# names (ivf_flat_<entry point>_<type suffix>.cu). Each entry carries the
# macro's C++ source and its identifier, instantiate_raft_neighbors_ivf_flat_
# followed by the entry-point name, which must match the #define in the source.
macros = {
    entry_point: {
        "definition": macro_source,
        "name": "instantiate_raft_neighbors_ivf_flat_" + entry_point,
    }
    for entry_point, macro_source in (
        ("build", build_macro),
        ("extend", extend_macro),
        ("search", search_macro),
    )
}

# Write one translation unit per (type combination, entry point) pair. Each
# file holds the shared header, the macro definition, a single invocation of
# the macro with the concrete types, and an #undef of the macro.
#
# Files are written next to this script rather than into the current working
# directory. Running the generator from another directory (e.g. the repository
# root) therefore no longer scatters .cu files in the wrong place. Output is
# written as UTF-8 with '\n' line endings so results do not depend on the
# platform's defaults.
import os

_output_dir = os.path.dirname(os.path.abspath(__file__))

for type_path, (T, IdxT) in types.items():
    for macro_path, macro in macros.items():
        path = f"ivf_flat_{macro_path}_{type_path}.cu"
        with open(
            os.path.join(_output_dir, path), "w", encoding="utf-8", newline="\n"
        ) as f:
            f.write(header)
            f.write(macro["definition"])

            # The macro body already ends in ';', so this trailing ';' yields
            # an empty declaration; kept so the generated text is unchanged.
            f.write(f"{macro['name']}({T}, {IdxT});\n\n")
            f.write(f"#undef {macro['name']}\n")

        # Report the file with a "src/neighbors/" prefix -- presumably its path
        # relative to the project source root, for the build's source list
        # (TODO confirm).
        print(f"src/neighbors/{path}")
